{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataset import torch, os, LocalDataset, transforms, np, get_class, num_classes, preprocessing, Image, m, s\n",
    "from config import *\n",
    "\n",
    "from torch import nn\n",
    "from torch.optim import SGD\n",
    "from torch.autograd import Variable\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision.models import resnet, vgg\n",
    "\n",
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn.metrics import confusion_matrix\n",
    "from sklearn.metrics import f1_score\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "from numpy import unravel_index\n",
    "import gc\n",
    "import argparse\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `m` and `s` come from the dataset module and are handed to Normalize as mean / std.\n",
    "mean = m\n",
    "std_dev = s\n",
    "\n",
    "# Inference-time pipeline: resize to 224x224, convert to a tensor, then normalize.\n",
    "pipeline_steps = [\n",
    "    transforms.Resize((224, 224)),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(mean, std_dev),\n",
    "]\n",
    "transform = transforms.Compose(pipeline_steps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "results/resnet152/resnet152.pt\n"
     ]
    }
   ],
   "source": [
    "# Build an untrained ResNet-152 with one output per class; weights come from the checkpoint below.\n",
    "classes = {\"num_classes\": len(num_classes)}\n",
    "model_name = \"resnet152\"\n",
    "resnet152_model = resnet.resnet152(pretrained=False, **classes)\n",
    "model = resnet152_model\n",
    "\n",
    "# Build the checkpoint path once so the printed path and the loaded path cannot drift apart.\n",
    "checkpoint_path = str(RESULTS_PATH) + \"/\" + str(model_name) + \"/\" + str(model_name) + \".pt\"\n",
    "print (checkpoint_path)\n",
    "\n",
    "# Map every stored tensor to CPU storage so a checkpoint saved from a GPU run\n",
    "# also loads on a CPU-only machine; the model can still be moved to the GPU later.\n",
    "state_dict = torch.load(checkpoint_path, map_location=lambda storage, location: storage)\n",
    "model.load_state_dict(state_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_sample(image_path, model=model, model_name=model_name):\n",
    "    \"\"\"Classify one image file and return its predicted class label.\n",
    "\n",
    "    Args:\n",
    "        image_path: Path of the image to classify; it is opened with\n",
    "            ``Image.open`` and converted to RGB.\n",
    "        model: Network used for inference. The default is bound to the\n",
    "            global ``model`` at the time this cell is executed.\n",
    "        model_name: Unused; kept so existing callers keep working.\n",
    "\n",
    "    Returns:\n",
    "        The value ``get_class`` returns for the highest-scoring output index.\n",
    "    \"\"\"\n",
    "    # convert() returns a new image, so the file handle can be released right away.\n",
    "    with Image.open(image_path) as raw_image:\n",
    "        im = transform(raw_image.convert(\"RGB\"))\n",
    "\n",
    "    use_gpu = USE_CUDA and cuda_available\n",
    "    if use_gpu:\n",
    "        model = model.cuda()\n",
    "    model.eval()\n",
    "\n",
    "    # Add a leading batch dimension of size 1.\n",
    "    x = im.unsqueeze(0)\n",
    "    if use_gpu:\n",
    "        x = x.cuda()\n",
    "\n",
    "    # Inference only: skip autograd bookkeeping to save memory and time.\n",
    "    with torch.no_grad():\n",
    "        pred = model(x).cpu().numpy().copy()\n",
    "\n",
    "    idx_max_pred = np.argmax(pred)\n",
    "    # np.argmax works on the flattened array; the modulo folds that flat index\n",
    "    # back into the range of class indices.\n",
    "    idx_classes = idx_max_pred % classes[\"num_classes\"]\n",
    "    return get_class(idx_classes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'1_ford_explorer'"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_sample(\"sample/fordx.jpg\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'1_ford_explorer'"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_sample(\"sample/ford.jpg\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'3_honda_civic'"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_sample(\"sample/honda.jpg\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'3_honda_civic'"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_sample(\"sample/hondax.jpg\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'3_honda_civic'"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_sample(\"sample/honday.jpg\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'2_nissan_altima'"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_sample(\"sample/nissan.jpg\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'2_nissan_altima'"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_sample(\"sample/nissan2.jpg\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'2_nissan_altima'"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_sample(\"sample/nissan_1.jpg\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'2_nissan_altima'"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_sample(\"sample/nissanx.jpg\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test con immagini prese dal web"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'1_ford_explorer'"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_sample(\"sample/ford_web.png\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'3_honda_civic'"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_sample(\"sample/honda_web.jpeg\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'2_nissan_altima'"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_sample(\"sample/nissan_web.jpeg\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'1_ford_explorer'"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_sample(\"sample/ford_explor.jpg\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "from glob import glob\n",
    "\n",
    "# glob() returns paths in arbitrary, filesystem-dependent order;\n",
    "# sort them so the per-class result files are reproducible between runs.\n",
    "ford = sorted(glob(\"test_images/Ford_Explorer/*\"))\n",
    "nissan = sorted(glob(\"test_images/Nissan_Altima/*\"))\n",
    "honda = sorted(glob(\"test_images/Honda_Civic/*\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One \"path,predicted_label\" line per Ford test image.\n",
    "ford_results = \"\".join(\n",
    "    image_path + \",\" + test_sample(image_path) + \"\\n\"\n",
    "    for image_path in ford\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "62\n"
     ]
    }
   ],
   "source": [
    "# Classify every Nissan test image once, keeping (path, prediction) pairs.\n",
    "nissan_predictions = [(image_path, test_sample(image_path)) for image_path in nissan]\n",
    "\n",
    "# Number of images whose prediction equals \"2_nissan_altima\".\n",
    "count = sum(1 for _path, label in nissan_predictions if label == \"2_nissan_altima\")\n",
    "\n",
    "# One \"path,predicted_label\" line per image.\n",
    "nissan_results = \"\".join(\n",
    "    image_path + \",\" + label + \"\\n\"\n",
    "    for image_path, label in nissan_predictions\n",
    ")\n",
    "print (count)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One \"path,predicted_label\" line per Honda test image.\n",
    "honda_results = \"\".join(\n",
    "    image_path + \",\" + test_sample(image_path) + \"\\n\"\n",
    "    for image_path in honda\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The context manager closes the file even if the write raises.\n",
    "with open(\"ford.csv\", \"w\") as results_file:\n",
    "    results_file.write(ford_results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The context manager closes the file even if the write raises.\n",
    "with open(\"nissan.csv\", \"w\") as results_file:\n",
    "    results_file.write(nissan_results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The context manager closes the file even if the write raises.\n",
    "with open(\"honda.csv\", \"w\") as results_file:\n",
    "    results_file.write(honda_results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
